
package com.iwuyc.tools.commons.structs.collections;

import com.iwuyc.tools.commons.structs.collections.NavigateTreeNode;
import com.iwuyc.tools.commons.util.collection.CollectionUtil;
import lombok.Getter;
import lombok.ToString;

import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.IntConsumer;

/**
 * A tree whose nodes can be accessed directly by primary key.
 *
 * <p>Nodes are stored in this {@link HashMap} keyed by primary key; parent/child links are held by the
 * nodes themselves. The primary keys of parentless nodes are tracked separately so the roots can be
 * enumerated. Adding a node under a parent key that is not present yet creates a placeholder ("virtual")
 * parent node, with no value assigned to it, and registers that placeholder as a root.
 *
 * <p>This class performs no synchronization of its own and is not safe for concurrent modification.
 *
 * @param <K> primary-key type
 * @param <V> node value type
 * @author Neil
 * @version 1.0.0
 * @since 2022.1
 */
@ToString
public class NavigateTree<K extends Serializable, V extends Serializable> extends HashMap<K, NavigateTreeNode<K, V>> implements Serializable {
    private static final long serialVersionUID = -1475220629986796510L;

    /**
     * Primary keys of the nodes that currently have no parent in this tree.
     */
    private final Collection<K> rootNodePrimaryKeys = new HashSet<>();

    /**
     * Creates an empty tree with the given initial map capacity.
     *
     * @param capacitySize initial capacity handed to {@link HashMap#HashMap(int)}
     */
    public NavigateTree(int capacitySize) {
        super(capacitySize);
    }

    /**
     * Creates an empty tree with the default map capacity.
     */
    public NavigateTree() {
    }

    /**
     * Creates a new empty tree with an initial capacity of zero. The returned tree is still mutable.
     *
     * @param <K> primary-key type
     * @param <V> node value type
     * @return a new empty tree
     */
    public static <K extends Serializable, V extends Serializable> NavigateTree<K, V> empty() {
        return new NavigateTree<>(0);
    }

    /**
     * Pushes the right setter and then the left setter of {@code node}, so the left one is popped first.
     */
    private static <K extends Serializable, V extends Serializable> void scoreSetterPush(final Deque<ScoreSetter<K, V>> stack, final NavigateTreeNode<K, V> node) {
        stack.push(ScoreSetter.createRight(node));
        stack.push(ScoreSetter.createLeft(node));
    }

    /**
     * Adds {@code node} and all of its descendants to this tree.
     *
     * <p>Every visited node is added through {@link #add(Serializable, Serializable, Serializable)} using
     * the parent key stored on that node. A node is always added before its children.
     *
     * @param node the node to add; {@code null} is ignored
     * @throws IllegalArgumentException if adding a node would create a cycle
     */
    public void add(NavigateTreeNode<K, V> node) {
        if (null == node) {
            return;
        }
        final Deque<NavigateTreeNode<K, V>> nodeStack = new ArrayDeque<>();
        nodeStack.push(node);
        while (!nodeStack.isEmpty()) {
            final NavigateTreeNode<K, V> item = nodeStack.pop();
            this.add(item.getParentKey(), item.getPrimaryKey(), item.getVal());
            if (item.hasChildren()) {
                nodeStack.addAll(item.getChildren().values());
            }
        }
    }

    /**
     * Adds a root node to this tree, or updates the value of an existing node and marks it as a root.
     *
     * @param primaryKey primary key
     * @param val        data
     */
    public void add(K primaryKey, V val) {
        this.add(null, primaryKey, val);
    }

    /**
     * Adds a normal node, or a root node if {@code parentKey} is {@code null}.
     *
     * <p>If a node with {@code primaryKey} already exists, its value is replaced and it is attached to
     * {@code parentKey}. The parent key recorded on an existing node is not changed here, because the
     * node is only created when absent. When the parent is not present yet, a placeholder parent node is
     * created and registered as a root. Any node attached to a parent is removed from the root set.
     *
     * @param parentKey  parent key; when {@code null}, the node is registered as a root
     * @param primaryKey primary key
     * @param val        data
     * @throws IllegalArgumentException if {@code parentKey} equals {@code primaryKey} or is a descendant
     *                                  of the existing node {@code primaryKey} (circular reference)
     */
    public void add(K parentKey, K primaryKey, V val) {
        checkCircular(parentKey, primaryKey);

        final NavigateTreeNode<K, V> currentNode = super.computeIfAbsent(primaryKey, key -> NavigateTreeNode.createNode(parentKey, primaryKey));
        currentNode.setVal(val);

        if (null == parentKey) {
            rootNodePrimaryKeys.add(currentNode.getPrimaryKey());
            return;
        }

        NavigateTreeNode<K, V> parentNode = super.get(parentKey);
        if (null == parentNode) {
            // The parent has not been added yet: create a placeholder for it and treat it as a root
            // until it is itself attached to a parent.
            parentNode = NavigateTreeNode.createNode(null, parentKey);
            super.put(parentKey, parentNode);
            rootNodePrimaryKeys.add(parentNode.getPrimaryKey());
        }
        parentNode.addChild(currentNode);
        // A node that now has a parent (real or placeholder) must not stay in the root set; otherwise
        // it would be visited both as a root and inside its parent's subtree.
        rootNodePrimaryKeys.remove(currentNode.getPrimaryKey());
    }

    /**
     * Rejects the edge {@code parentKey -> primaryKey} if it would make the structure cyclic.
     *
     * <p>A cycle appears exactly when the prospective parent is the node itself or is already reachable
     * from it through child links. The existing subtree below {@code primaryKey} is searched for
     * {@code parentKey}. A visited set guards against looping over data that is already malformed.
     *
     * @param parentKey  prospective parent key; {@code null} never forms a cycle
     * @param primaryKey key of the node being added
     * @throws IllegalArgumentException if the edge would create a cycle
     */
    private void checkCircular(K parentKey, K primaryKey) {
        if (null == parentKey) {
            return;
        }
        if (Objects.equals(parentKey, primaryKey)) {
            throw new IllegalArgumentException("Circular reference");
        }
        final NavigateTreeNode<K, V> node = super.get(primaryKey);
        if (null == node || !node.hasChildren()) {
            // A node that does not exist yet, or has no children, has no descendants.
            return;
        }
        final Deque<NavigateTreeNode<K, V>> pending = new ArrayDeque<>(node.getChildren().values());
        final Collection<K> visited = new HashSet<>();
        while (!pending.isEmpty()) {
            final NavigateTreeNode<K, V> descendant = pending.pop();
            final K descendantKey = descendant.getPrimaryKey();
            if (!visited.add(descendantKey)) {
                continue;
            }
            if (Objects.equals(descendantKey, parentKey)) {
                throw new IllegalArgumentException("Circular reference");
            }
            if (descendant.hasChildren()) {
                pending.addAll(descendant.getChildren().values());
            }
        }
    }

    /**
     * Removes the node with {@code primaryKey}, together with the keys returned by its
     * {@code getAllChildrenPrimaryKey()}, from this container and from the root set.
     *
     * <p>NOTE(review): the removed node is not detached from its parent node's children, so the parent
     * may still reference it (e.g. during {@link #scoreCalc(int)}). Detaching needs a child-removal
     * operation on {@code NavigateTreeNode}, which is not visible here; confirm and add it.
     *
     * @param primaryKey primary key of the node to remove
     * @return {@code true} if a node was removed, {@code false} if no node had that key
     */
    public boolean remove(K primaryKey) {
        final NavigateTreeNode<K, V> removeNode = super.remove(primaryKey);
        if (null == removeNode) {
            return false;
        }
        this.rootNodePrimaryKeys.remove(removeNode.getPrimaryKey());
        // Remove every child from this container and keep the root index consistent with the map.
        for (K childKey : removeNode.getAllChildrenPrimaryKey()) {
            super.remove(childKey);
            this.rootNodePrimaryKeys.remove(childKey);
        }

        return true;
    }

    /**
     * Looks up a node by primary key.
     *
     * @param primaryKey primary key
     * @return the node, or {@link Optional#empty()} if absent
     */
    public Optional<NavigateTreeNode<K, V>> get(K primaryKey) {
        return Optional.ofNullable(super.get(primaryKey));
    }

    /**
     * Adds every node of {@code otherNavigateTree} (and their descendants) to this tree.
     *
     * @param otherNavigateTree source tree; {@code null} is ignored
     * @throws IllegalArgumentException if adding a node would create a cycle
     */
    public void addAll(NavigateTree<K, V> otherNavigateTree) {
        if (null == otherNavigateTree) {
            return;
        }
        this.addAll(otherNavigateTree.values());
    }

    /**
     * Adds each node of {@code navigateTreeNodes} (and its descendants) to this tree.
     *
     * @param navigateTreeNodes nodes to add; an empty or {@code null} collection is ignored
     * @throws IllegalArgumentException if adding a node would create a cycle
     */
    public void addAll(Collection<NavigateTreeNode<K, V>> navigateTreeNodes) {
        if (CollectionUtil.isEmpty(navigateTreeNodes)) {
            return;
        }
        for (NavigateTreeNode<K, V> otherItem : navigateTreeNodes) {
            this.add(otherItem);
        }
    }

    /**
     * Computes left/right scores for every node reachable from the root nodes. Siblings are visited in
     * the iteration order of their parent's children map.
     *
     * @param scoreStart first score value assigned within each root's tree
     */
    public void scoreCalc(final int scoreStart) {
        this.scoreCalc(null, scoreStart);
    }

    /**
     * Computes left/right scores for every node reachable from the root nodes.
     *
     * <p>Each root tree is numbered independently starting at {@code scoreStart}, so different root
     * trees receive overlapping score ranges.
     *
     * @param sorted     ordering of sibling nodes (children of the same parent); siblings that sort
     *                   first are numbered first. {@code null} keeps the children map's iteration order
     * @param scoreStart first score value assigned within each root's tree
     */
    public void scoreCalc(final Comparator<NavigateTreeNode<K, V>> sorted, int scoreStart) {
        for (NavigateTreeNode<K, V> rootNode : this.getRootNodes()) {
            this.scoreCalc(rootNode, sorted, scoreStart);
        }
    }

    /**
     * Assigns left/right scores to {@code rootNode} and its subtree with an explicit stack.
     *
     * <p>A counter starting at {@code scoreStart} is incremented once per assignment. A node's left
     * score is assigned when it is entered, and its right score after all of its descendants have been
     * numbered, so every descendant's scores lie strictly between its ancestor's left and right scores.
     *
     * @param rootNode   root node instance
     * @param sorted     sibling ordering, or {@code null} for the children map's iteration order
     * @param scoreStart first score value
     */
    private void scoreCalc(NavigateTreeNode<K, V> rootNode, Comparator<NavigateTreeNode<K, V>> sorted, int scoreStart) {
        final Deque<ScoreSetter<K, V>> stack = new ArrayDeque<>();
        scoreSetterPush(stack, rootNode);
        int counter = scoreStart;
        while (!stack.isEmpty()) {
            final ScoreSetter<K, V> scoreSetter = stack.pop();
            scoreSetter.setScore(counter++);
            // Children are expanded only when the node is entered (its left setter), never on exit.
            if (!scoreSetter.hasChildren()) {
                continue;
            }

            final NavigateTreeNode<K, V> current = scoreSetter.getNode();
            Collection<NavigateTreeNode<K, V>> children = current.getChildren().values();
            if (null != sorted) {
                final List<NavigateTreeNode<K, V>> sortedChildren = new ArrayList<>(children);
                // Higher-priority siblings must be handled first. The stack is LIFO, so sort with the
                // reversed comparator: the child that sorts first is pushed last and popped first.
                sortedChildren.sort(sorted.reversed());
                children = sortedChildren;
            }

            for (NavigateTreeNode<K, V> child : children) {
                scoreSetterPush(stack, child);
            }
        }
    }

    /**
     * Returns the root nodes of this tree.
     *
     * <p>Root keys whose node is no longer present in the map are skipped. The result is a new set
     * with unspecified iteration order, or an immutable empty set when there are no root keys.
     *
     * @return all root nodes
     */
    public Collection<NavigateTreeNode<K, V>> getRootNodes() {
        if (CollectionUtil.isEmpty(rootNodePrimaryKeys)) {
            return Collections.emptySet();
        }
        final int size = rootNodePrimaryKeys.size();
        final Collection<NavigateTreeNode<K, V>> result = new HashSet<>(size);
        for (K k : rootNodePrimaryKeys) {
            final NavigateTreeNode<K, V> node = super.get(k);
            if (null == node) {
                continue;
            }
            result.add(node);
        }
        return result;
    }

    /**
     * Two trees are equal when they are the same class, hold equal map entries and have equal root-key
     * sets.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        NavigateTree<?, ?> that = (NavigateTree<?, ?>) o;
        return Objects.equals(rootNodePrimaryKeys, that.rootNodePrimaryKeys);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: combines the map hash and the root-key set.
     */
    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), rootNodePrimaryKeys);
    }

    /**
     * Stack entry for {@link #scoreCalc}: writes either the left or the right score of one node.
     */
    private static class ScoreSetter<K extends Serializable, V extends Serializable> {
        /** {@code true} if this entry writes the left score (entering the node). */
        private final boolean left;
        /** Setter invoked with the score ({@code node::setLeft} or {@code node::setRight}). */
        private final IntConsumer scoreSetterMethod;
        @Getter
        private final NavigateTreeNode<K, V> node;

        private ScoreSetter(boolean left, IntConsumer scoreSetterMethod, NavigateTreeNode<K, V> node) {
            this.left = left;
            this.scoreSetterMethod = scoreSetterMethod;
            this.node = node;
        }

        /** Creates an entry that writes {@code node}'s left score. */
        public static <K extends Serializable, V extends Serializable> ScoreSetter<K, V> createLeft(NavigateTreeNode<K, V> node) {
            return new ScoreSetter<>(true, node::setLeft, node);
        }

        /** Creates an entry that writes {@code node}'s right score. */
        public static <K extends Serializable, V extends Serializable> ScoreSetter<K, V> createRight(NavigateTreeNode<K, V> node) {
            return new ScoreSetter<>(false, node::setRight, node);
        }

        /** Writes {@code score} through the wrapped setter. */
        public void setScore(int score) {
            this.scoreSetterMethod.accept(score);
        }

        /**
         * Returns whether this is a left (entry) setter whose node has children, i.e. whether the
         * children should be expanded after this setter runs.
         */
        public boolean hasChildren() {
            return this.left && this.node != null && this.node.hasChildren();
        }
    }
}
